# Copyright 2023 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import List, Tuple, Union

import jax.numpy as jnp
import numpy as np
import tensorflow as tf
from tensorflow.keras.optimizers import Adam

import secretflow as sf
from secretflow.device import PYU, PYUObject, proxy
from secretflow.ml.nn.callbacks.attack import AttackCallback
from secretflow.utils.communicate import ForwardData


@proxy(PYUObject)
class ExploitAttackTrainer(object):
    """Surrogate-model trainer used by ``ExploitAttack``, run on the attacking party.

    The surrogate model is fitted on the attacker's aggregated hidden outputs
    with the three-term loss built in ``custom_loss``:

    * a binary cross-entropy term against hard labels obtained by rounding the
      victim's predictions ``p_hats``, divided by ``H_y``;
    * an L1 term matching the gradient of the surrogate's loss w.r.t. its input
      to the gradients the attacker received (``GradLoss``);
    * a KL term pulling the surrogate's batch-mean prediction towards the prior
      label distribution ``y_pri`` (``KLLoss``).
    """

    def __init__(self, surrogate_model, H_y, y_pri) -> None:
        """
        Args:
            surrogate_model: zero-argument callable that builds the surrogate
                model. It is called once here; the result must be callable on a
                batch of inputs and expose ``trainable_variables``.
            H_y: divisor of the cross-entropy term in ``custom_loss``
                (presumably the label entropy H(y) -- confirm with the caller).
            y_pri: prior over the two classes, compared against the batch mean
                of ``[1 - p, p]`` in ``KLLoss`` (presumably
                ``[P(y=0), P(y=1)]``).
        """
        self.surrogate_model = surrogate_model()
        self.H_y = H_y
        self.y_pri = y_pri

    @staticmethod
    def BCELoss(y_true, y_pred):
        """Return Keras binary cross-entropy (mean over the last axis)."""
        binary_crossentropy = tf.keras.losses.binary_crossentropy(y_true, y_pred)
        return binary_crossentropy

    def GradLoss(self, grad_target, X, y_true):
        """L1 distance between ``grad_target`` and the surrogate's input gradient.

        The surrogate's mean BCE loss on ``(X, y_true)`` is differentiated with
        respect to ``X``; the result is multiplied by the batch size before
        being compared with ``grad_target``.

        Args:
            grad_target: gradients received by the attacker for this batch.
            X: batch of surrogate inputs (watched by the inner tape).
            y_true: hard labels for the batch.

        Returns:
            Scalar L1 norm over all elements of the difference.
        """
        with tf.GradientTape() as tape:
            tape.watch(X)
            y_pred = self.surrogate_model(X)
            y_pred = tf.reshape(y_pred, shape=(-1,))
            binary_crossentropy = self.BCELoss(y_true, y_pred)
            loss = tf.reduce_mean(binary_crossentropy)
        gradients = tape.gradient(loss, X)
        gradients = tf.cast(gradients, dtype=tf.float32)
        # The surrogate loss is batch-averaged; scaling by the batch size
        # presumably undoes that averaging to match grad_target's scale.
        # NOTE(review): confirm against how the victim reduces its loss.
        gradient_loss_value = tf.norm(
            (grad_target - gradients * y_true.shape[0]), ord=1
        )
        return gradient_loss_value

    @staticmethod
    def KLLoss(y_pri, y_pred):
        """KL(y_pri || q), where q is the batch mean of ``[1 - y_pred, y_pred]``.

        ``eps`` guards against division by zero and ``log(0)``.
        """
        eps = 1e-9
        y_pred_mean = tf.reduce_mean(
            tf.stack([1 - y_pred, y_pred], axis=1), axis=0, keepdims=True
        )
        # -p * log(q / p) == p * log(p / q)
        kl_divergence = -y_pri * tf.math.log(y_pred_mean / (y_pri + eps) + eps)
        kl_divergence = tf.reduce_sum(kl_divergence)
        return kl_divergence

    def custom_loss(
        self,
        X,
        y_true,
        y_pred,
        y_pri,
        H_y,
        grad_target,
        alpha_acc=1,
        alpha_grad=0.001,
        alpha_kl=0.001,
    ):
        """Weighted sum of the accuracy, gradient-matching and KL loss terms.

        Returns:
            ``alpha_acc * BCE / H_y + alpha_grad * GradLoss + alpha_kl * KLLoss``.
        """
        H_y = tf.cast(H_y, dtype=tf.float32)
        y_pri = tf.cast(y_pri, dtype=tf.float32)
        loss_acc = self.BCELoss(y_true, y_pred) / H_y
        loss_grad = self.GradLoss(grad_target, X, y_true)
        loss_kl = self.KLLoss(y_pri, y_pred)
        # total loss; the weights can be adjusted as needed
        total_loss = alpha_acc * loss_acc + alpha_grad * loss_grad + alpha_kl * loss_kl
        return total_loss

    @staticmethod
    def _hard_labels(p_hats):
        """Round predictions to hard labels, dropping size-1 non-batch axes.

        Behaves like ``np.squeeze`` except that the leading (batch) axis is
        always kept, so a single-sample input still yields a 1-D array that
        ``from_tensor_slices`` can slice.
        """
        labels = np.round(np.asarray(p_hats))
        if labels.ndim <= 1:
            return labels.reshape(-1)
        kept_dims = [dim for dim in labels.shape[1:] if dim != 1]
        return labels.reshape(labels.shape[0], *kept_dims)

    def get_dataloader(
        self,
        x,
        p_hats,
        batch_size=128,
        mode="train",
        grads=None,
    ):
        """Build a batched ``tf.data.Dataset`` for training or evaluation.

        Args:
            x: surrogate inputs (first axis is the sample axis).
            p_hats: victim predictions; rounded to hard labels.
            batch_size: batch size of the returned dataset.
            mode: ``"train"`` yields ``(x, y, grads)``; ``"eval"`` yields ``(x, y)``.
            grads: per-sample target gradients; required when ``mode == "train"``.

        Raises:
            ValueError: if ``mode`` is unknown, or ``grads`` is missing in train mode.
        """
        if mode not in ("train", "eval"):
            raise ValueError(
                f"Unknown dataloader mode {mode!r}; expected 'train' or 'eval'."
            )
        features = tf.cast(x, dtype=tf.float32)
        labels = tf.cast(self._hard_labels(p_hats), dtype=tf.float32)
        if mode == "train":
            if grads is None:
                raise ValueError("grads must be provided when mode='train'.")
            tensors = (features, labels, tf.cast(grads, dtype=tf.float32))
        else:
            tensors = (features, labels)
        return tf.data.Dataset.from_tensor_slices(tensors).batch(batch_size)

    def train(
        self,
        x_train,
        p_hats,
        grads,
        epochs,
        batch_size,
        alpha_acc=1,
        alpha_grad=0.001,
        alpha_kl=0.001,
        learning_rate=0.01,
    ):
        """Fit the surrogate model with ``custom_loss`` using Adam.

        Per epoch, logs the mean per-batch loss, accuracy of rounded
        predictions against the hard labels, and AUC.

        Raises:
            ValueError: if the training data is empty.
        """
        optimizer = Adam(learning_rate=learning_rate)

        dataloader = self.get_dataloader(
            x=x_train, p_hats=p_hats, grads=grads, batch_size=batch_size, mode="train"
        )

        for epoch in range(epochs):
            total_loss = 0.0
            num_batches = 0
            correct_predictions = 0
            total_samples = 0

            y_true_list = []
            y_pred_list = []
            for X, y, grad_target in dataloader:
                # The outer tape also records the inner tape used by GradLoss,
                # so the gradient-matching term is differentiated as well.
                with tf.GradientTape() as tape:
                    y_pred = self.surrogate_model(X)
                    y_pred = tf.reshape(y_pred, shape=(-1,))
                    loss = self.custom_loss(
                        X,
                        y,
                        y_pred,
                        self.y_pri,
                        self.H_y,
                        grad_target,
                        alpha_acc=alpha_acc,
                        alpha_grad=alpha_grad,
                        alpha_kl=alpha_kl,
                    )

                gradients = tape.gradient(
                    loss, self.surrogate_model.trainable_variables
                )
                optimizer.apply_gradients(
                    zip(gradients, self.surrogate_model.trainable_variables)
                )
                total_loss += loss.numpy()
                num_batches += 1

                y_pred_list.append(y_pred)
                y_true_list.append(y)
                predicted_labels = tf.round(y_pred)  # round to 0 or 1
                correct_predictions += int(
                    tf.reduce_sum(tf.cast(predicted_labels == y, tf.int32))
                )
                total_samples += X.shape[0]

            if num_batches == 0:
                raise ValueError("ExploitAttackTrainer.train received an empty dataset.")

            # Each batch loss is already a batch mean, so average over batches.
            average_loss = total_loss / num_batches

            # Concatenate first: batches can differ in size (last batch).
            auc_metric = tf.keras.metrics.AUC()
            auc_metric.update_state(
                tf.concat(y_true_list, axis=0), tf.concat(y_pred_list, axis=0)
            )
            auc_value = auc_metric.result().numpy()
            accuracy = correct_predictions / total_samples

            logging.info(
                "==================================================================================================="
            )
            logging.info(
                f"Epoch {epoch+1}/{epochs}, Loss: {average_loss:.4f}, Accuracy: {accuracy:.4f}, AUC: {auc_value:.4f}"
            )
            logging.info(
                "==================================================================================================="
            )

    def evaluate(
        self,
        x_test,
        p_hats,
        batch_size=128,
    ):
        """Evaluate the surrogate against hard labels derived from ``p_hats``.

        Returns:
            ``{"accuracy": float, "auc_value": float}``.

        Raises:
            ValueError: if the evaluation data is empty.
        """
        dataloader = self.get_dataloader(
            x=x_test,
            p_hats=p_hats,
            batch_size=batch_size,
            mode="eval",
        )

        correct_predictions = 0
        total_samples = 0
        y_true_list = []
        y_pred_list = []
        for X, y in dataloader:
            y_pred = self.surrogate_model(X)
            y_pred = tf.reshape(y_pred, shape=(-1,))
            y_pred_list.append(y_pred)
            y_true_list.append(y)
            predicted_labels = tf.round(y_pred)  # round to 0 or 1
            correct_predictions += int(
                tf.reduce_sum(tf.cast(predicted_labels == y, tf.int32))
            )
            total_samples += X.shape[0]

        if total_samples == 0:
            raise ValueError("ExploitAttackTrainer.evaluate received an empty dataset.")

        auc_metric = tf.keras.metrics.AUC()
        auc_metric.update_state(
            tf.concat(y_true_list, axis=0), tf.concat(y_pred_list, axis=0)
        )
        auc_value = auc_metric.result().numpy()
        accuracy = correct_predictions / total_samples
        logging.info(f"Test Accuracy: {accuracy:.4f}, AUC: {auc_value:.4f}")
        metrics = {"accuracy": float(accuracy), "auc_value": float(auc_value)}
        return metrics


class ExploitAttack(AttackCallback):
    def __init__(
        self,
        attack_party: PYU,
        surrogate_model=None,
        epoch=5,
        batch_size=128,
        H_y=None,
        y_pri=None,
        alpha_acc=1,
        alpha_grad=0.001,
        alpha_kl=0.001,
        **params,
    ):
        """
        ExploitAttack in tensorflow backend.

        During the last epoch of split-learning training, the callback collects:

        * the attacker's aggregated hidden outputs;
        * the victim model's predictions on them;
        * the gradients the attacker receives.

        At the end of training it fits a surrogate model on this data via
        ``ExploitAttackTrainer``. ``get_attack_metrics`` then evaluates the
        surrogate on data collected during evaluation.

        Args:
            attack_party (PYU): the attacking party; the surrogate is built,
                trained and evaluated on this device.
            surrogate_model: zero-argument callable that builds the surrogate
                model. It is invoked unconditionally when the trainer is
                created, so it must be provided despite the ``None`` default.
            epoch (int): number of epochs to train the surrogate model. Default is 5.
            batch_size (int): batch size for training and evaluating the
                surrogate model. Default is 128.
            H_y: divisor of the cross-entropy term of the surrogate loss
                (presumably the label entropy H(y) -- confirm).
            y_pri: prior distribution over the two labels, matched against the
                surrogate's mean prediction by the KL loss term.
            alpha_acc: weight of accuracy loss in custom loss, default 1.
            alpha_grad: weight of grad loss in custom loss, default 0.001.
            alpha_kl: weight of kl loss in custom loss, default 0.001.
            **params: extra keyword arguments; accepted but not used here.

        Note:
            This class is intended to be used as a callback within a split
            learning training loop.
        """
        super().__init__()
        # Per-batch PYUObjects collected during the last epoch; concatenated
        # into a single PYUObject in on_train_end / get_attack_metrics.
        self.grads = []
        self.p_hats = []
        self.eval_p_hats = []
        self.attack_party = attack_party
        self.epoch = epoch
        self.batch_size = batch_size
        self.agg_hiddens = []
        self.eval_hiddens = []
        self.attacker = ExploitAttackTrainer(
            device=attack_party,
            surrogate_model=surrogate_model,
            H_y=H_y,
            y_pri=y_pri,
        )
        self.alpha_acc = alpha_acc
        self.alpha_kl = alpha_kl
        self.alpha_grad = alpha_grad

    @staticmethod
    def convert_to_ndarray(*data: List) -> Union[List[jnp.ndarray], jnp.ndarray]:
        """Convert tf tensors / numpy arrays to jax arrays (other values pass through).

        A single positional argument is unwrapped first. A list/tuple argument is
        converted element-wise and returned as a list.
        """

        def _convert_to_ndarray(hidden):
            if isinstance(hidden, jnp.ndarray):
                return hidden
            if isinstance(hidden, tf.Tensor):
                return jnp.array(hidden.numpy())
            if isinstance(hidden, np.ndarray):
                return jnp.array(hidden)
            return hidden

        if len(data) == 1:
            # After packing and unpacking through a PYU, a tuple of length 1 is
            # obtained if 'num_return' is not specified to the PYU.
            data = data[0]
        if isinstance(data, (list, tuple)):
            return [_convert_to_ndarray(d) for d in data]
        return _convert_to_ndarray(data)

    @staticmethod
    def convert_to_tensor(hidden: Union[List, Tuple], backend: str):
        """Convert ``hidden`` (or each element of it) to a tf tensor for the tensorflow backend.

        For any other backend the input is returned unchanged.
        """
        if backend == "tensorflow":
            if isinstance(hidden, (list, tuple)):
                hidden = [tf.convert_to_tensor(d) for d in hidden]
            else:
                hidden = tf.convert_to_tensor(hidden)

        return hidden

    def _concat_on_attack_party(self, values):
        """Concatenate a list of per-batch PYUObjects along axis 0 on the attacking party.

        Non-list values (already concatenated) are returned unchanged. This keeps
        repeated calls from re-concatenating a stacked 2-D array, which numpy
        would silently flatten.
        """
        if not isinstance(values, list):
            return values
        return self.attack_party(lambda items: np.concatenate(items, axis=0))(values)

    def on_train_begin(self, logs=None):
        # ``_workers`` and ``device_y`` are presumably populated by the
        # AttackCallback base / training framework before this hook runs.
        self.steps_per_epoch = sf.reveal(
            self._workers[self.device_y].get_steps_per_epoch()
        )

    def on_agglayer_forward_end(self, hiddens=None):
        """In the last epoch, record the victim's predictions and the aggregated hidden output.

        Data from the "train" stage goes to the training buffers; data from any
        other stage goes to the evaluation buffers.
        """
        victim_status = sf.reveal(self._workers[self.device_y].get_traing_status())
        epoch = victim_status["epoch"]
        stage = victim_status["stage"]
        slmodel_epochs = self.params["epochs"]
        if epoch != slmodel_epochs - 1:
            return

        y_pred = self._workers[self.device_y].predict(hiddens)
        p_hat = self.device_y(self.convert_to_ndarray)(y_pred)
        if stage == "train":
            self.p_hats.append(p_hat.to(self.attack_party))
        else:
            self.eval_p_hats.append(p_hat.to(self.attack_party))

        if isinstance(hiddens, PYUObject):
            hiddens = [hiddens]
        agg_obj = hiddens[0].to(self.attack_party)
        agg_hidden = self.attack_party(
            lambda forward_data: (
                forward_data.hidden
                if isinstance(forward_data, ForwardData)
                else forward_data
            )
        )(agg_obj)

        if stage == "train":
            self.agg_hiddens.append(agg_hidden)
        else:
            self.eval_hiddens.append(agg_hidden)

    def on_agglayer_backward_end(self, gradients=None):
        """In the last epoch, record the gradients sent to the attacking party."""
        victim_status = sf.reveal(self._workers[self.device_y].get_traing_status())
        epoch = victim_status["epoch"]
        slmodel_epochs = self.params["epochs"]
        if epoch != slmodel_epochs - 1:
            return

        grad_client = gradients[self.attack_party]
        grad_client_np = self.attack_party(self.convert_to_ndarray)(grad_client)
        self.grads.append(grad_client_np.to(self.attack_party))

    def on_train_end(self, logs=None):
        """Concatenate the collected training data and fit the surrogate model."""
        self.p_hats = self._concat_on_attack_party(self.p_hats)
        self.grads = self._concat_on_attack_party(self.grads)
        self.agg_hiddens = self._concat_on_attack_party(self.agg_hiddens)

        self.attacker.train(
            x_train=self.agg_hiddens,  # agg_hiddens
            p_hats=self.p_hats,
            grads=self.grads,
            epochs=self.epoch,
            batch_size=self.batch_size,
            alpha_acc=self.alpha_acc,
            alpha_kl=self.alpha_kl,
            alpha_grad=self.alpha_grad,
        )
        logging.info("on_train end")

    def get_attack_metrics(self):
        """Evaluate the surrogate on the collected evaluation data.

        Safe to call more than once.

        Returns:
            Revealed dict with ``"accuracy"`` and ``"auc_value"``.
        """
        self.eval_hiddens = self._concat_on_attack_party(self.eval_hiddens)
        self.eval_p_hats = self._concat_on_attack_party(self.eval_p_hats)
        metrics = self.attacker.evaluate(
            x_test=self.eval_hiddens,
            p_hats=self.eval_p_hats,
            batch_size=self.batch_size,
        )
        return sf.reveal(metrics)
